from math import exp
import os
from PIL.Image import FASTOCTREE
import numpy as np
import yaml
from torch.utils.data import Dataset

from image_synthesis.data.utils.image_path_dataset import ImagePaths
import image_synthesis.data.utils.imagenet_utils as imagenet_utils
from image_synthesis.utils.misc import instantiate_from_config

from torchvision.datasets import ImageFolder

class MyImageFolder(ImageFolder):
    """torchvision ``ImageFolder`` with an attached preprocessor and path filtering.

    Images are discovered by ``ImageFolder`` under ``root`` (one sub-folder per
    class). On top of that this class:

    * builds ``self.preprocessor`` from ``im_preprocessor_config`` via
      ``instantiate_from_config`` (it is stored here; no method of this class
      applies it);
    * can drop entries from ``self.imgs`` by substring matching on their paths
      (``filter_images``);
    * can group the image paths in ``self.imgs`` by class folder name
      (``folder_name_to_im_path``).
    """

    def __init__(self,
                 root='',
                 im_preprocessor_config=None):
        """Scan ``root`` and build the image preprocessor.

        Args:
            root: Directory handed to ``ImageFolder``.
            im_preprocessor_config: Config dict (``'target'`` / ``'params'``)
                passed to ``instantiate_from_config``. ``None`` selects the
                config returned by ``_default_im_preprocessor_config``.
        """
        super().__init__(root=root)

        # get preprocessor
        if im_preprocessor_config is None:
            # Build a fresh dict per instance: a dict literal used directly as
            # the default argument would be a single object shared by all calls.
            im_preprocessor_config = self._default_im_preprocessor_config()
        self.preprocessor = instantiate_from_config(im_preprocessor_config)

    @staticmethod
    def _default_im_preprocessor_config():
        """Return a new copy of the default preprocessor config."""
        return {
            'target': 'image_synthesis.data.utils.image_preprocessor.SimplePreprocessor',
            'params': {
                'size': 256,
                'random_crop': True,
                'horizon_flip': True,
            },
        }

    @staticmethod
    def _to_pattern_list(patterns):
        """Normalize a filter argument into a list of non-empty substrings.

        ``None`` becomes ``[]``; a str is split on ``','``; a tuple is turned
        into a list. Empty entries (e.g. from a trailing comma) are dropped
        because ``'' in path`` is always true and would otherwise make an
        ``unexpected`` filter reject every image.

        Raises:
            AssertionError: if ``patterns`` is not None, a str, a list or a tuple.
        """
        if patterns is None:
            return []
        if isinstance(patterns, str):
            patterns = patterns.split(',')
        elif isinstance(patterns, tuple):
            patterns = list(patterns)
        assert isinstance(patterns, list), \
            f'filter patterns must be a str, list or tuple, got {type(patterns)}'
        return [p for p in patterns if p]

    def filter_images(self, expected=None, unexpected=None, basename=False):
        """Keep only the images whose path matches the given substring filters.

        An entry of ``self.imgs`` is kept when its path (or only its file name
        if ``basename`` is True) contains *every* substring in ``expected`` and
        *none* of the substrings in ``unexpected``. ``self.imgs`` is replaced by
        the filtered list; nothing is returned. When both filters are None this
        is a no-op.

        Args:
            expected: Substring(s) that must all appear: a str (comma-separated
                for several), a list/tuple of str, or None.
            unexpected: Substring(s) that must not appear; same forms as
                ``expected``.
            basename: Match against ``os.path.basename(path)`` instead of the
                full path.

        Raises:
            AssertionError: if a filter is not None, a str, a list or a tuple.
        """
        if expected is None and unexpected is None:
            return

        expected = self._to_pattern_list(expected)
        unexpected = self._to_pattern_list(unexpected)

        def keep(im_path):
            haystack = os.path.basename(im_path) if basename else im_path
            return (all(p in haystack for p in expected)
                    and not any(p in haystack for p in unexpected))

        # The filter runs whichever of ``expected`` / ``unexpected`` was given.
        # NOTE(review): only ``self.imgs`` is reassigned. torchvision's
        # ImageFolder reads ``self.samples`` in ``__getitem__``/``__len__``,
        # which is not updated here — confirm whether callers rely on that.
        self.imgs = [(im_path, class_id)
                     for im_path, class_id in self.imgs
                     if keep(im_path)]

    def folder_name_to_im_path(self):
        """Group the image paths in ``self.imgs`` by class folder name.

        Returns:
            dict mapping ``self.classes[class_id]`` to the list of image paths
            carrying that class id, in ``self.imgs`` order. Classes with no
            entry in ``self.imgs`` are absent from the dict.
        """
        data = {}
        for im_path, class_id in self.imgs:
            data.setdefault(self.classes[class_id], []).append(im_path)
        return data
